/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <utility>

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/core/Expressions.h"
#include "velox/exec/PlanNodeStats.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/HiveConnectorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/SumNonPODAggregate.h"
#include "velox/exec/tests/utils/TempFilePath.h"

namespace facebook::velox::exec {
namespace {

using namespace facebook::velox::exec::test;

// Per-test-suite parameterization; see config() in the fixture below.
struct TestParams {
  // Value applied as QueryConfig::kStreamingAggregationMinOutputBatchRows.
  int32_t streamingMinOutputBatchSize;
  // Value applied as QueryConfig::kPreferredOutputBatchBytes.
  uint64_t preferredOutputBatchBytes;
};

// Parameterized fixture for streaming aggregation tests. Every test body runs
// under each TestParams combination of min output batch rows and preferred
// output batch bytes. Results are checked against DuckDB over a table named
// 'tmp'.
class StreamingAggregationTest
    : public HiveConnectorTestBase,
      public testing::WithParamInterface<TestParams> {
 protected:
  void SetUp() override {
    HiveConnectorTestBase::SetUp();
    // 'sumnonpod' uses a non-POD accumulator (NonPODInt64); its global
    // constructed/destructed counters let tests detect accumulator leaks.
    registerSumNonPODAggregate("sumnonpod", 64);
  }

  // Streaming aggregation min output batch rows for the current test
  // parameterization.
  int32_t flushRows() {
    return GetParam().streamingMinOutputBatchSize;
  }

  // Preferred output batch bytes for the current test parameterization.
  uint64_t preferredOutputBatchBytes() {
    return GetParam().preferredOutputBatchBytes;
  }

  // Applies the output batching configs to 'builder': the explicit
  // 'outputBatchSize' as preferred output rows, plus the parameterized min
  // output batch rows and preferred output batch bytes.
  //
  // Takes and returns the builder by rvalue reference so call sites can chain
  // on a temporary, e.g.
  //   config(AssertQueryBuilder(plan, runner), size).assertResults(sql);
  // The previous signature (by-value parameter, lvalue-reference return)
  // returned a reference to the parameter, whose lifetime past the return is
  // implementation-defined ([expr.call]) -- a latent dangling reference.
  AssertQueryBuilder&& config(
      AssertQueryBuilder&& builder,
      uint32_t outputBatchSize) {
    builder
        .config(
            core::QueryConfig::kPreferredOutputBatchRows,
            std::to_string(outputBatchSize))
        .config(
            core::QueryConfig::kStreamingAggregationMinOutputBatchRows,
            std::to_string(flushRows()))
        .config(
            core::QueryConfig::kPreferredOutputBatchBytes,
            std::to_string(preferredOutputBatchBytes()));
    return std::move(builder);
  }

  // Runs partial + final streaming aggregation over single-key 'keys' with a
  // mix of aggregates (including the non-POD 'sumnonpod' and masked variants)
  // and verifies the results against DuckDB.
  void testAggregation(
      const std::vector<VectorPtr>& keys,
      uint32_t outputBatchSize) {
    auto data = addPayload(keys, 1);
    createDuckDbTable(data);

    auto plan = PlanBuilder()
                    .values(data)
                    .partialStreamingAggregation(
                        {"c0"},
                        {"count(1)",
                         "min(c1)",
                         "max(c1)",
                         "sum(c1)",
                         "sumnonpod(1)",
                         "sum(cast(NULL as INT))",
                         "approx_percentile(c1, 0.95)"})
                    .finalAggregation()
                    .planNode();

    config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
        .assertResults(
            "SELECT c0, count(1), min(c1), max(c1), sum(c1), sum(1), sum(cast(NULL as INT))"
            "     , approx_quantile(c1, 0.95) "
            "FROM tmp GROUP BY 1");

    // Every non-POD accumulator constructed must have been destroyed.
    EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);

    // Same aggregation with the key appearing as the second input column.
    plan =
        PlanBuilder()
            .values(data)
            .project({"c1", "c0"})
            .partialStreamingAggregation(
                {"c0"},
                {"count(1)", "min(c1)", "max(c1)", "sum(c1)", "sumnonpod(1)"})
            .finalAggregation()
            .planNode();

    config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
        .assertResults(
            "SELECT c0, count(1), min(c1), max(c1), sum(c1), sum(1) FROM tmp GROUP BY 1");

    EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);

    // Test aggregation masks: one aggregate without a mask, two with the same
    // mask, one with a different mask.
    plan = PlanBuilder()
               .values(data)
               .project({"c0", "c1", "c1 % 7 = 0 AS m1", "c1 % 11 = 0 AS m2"})
               .partialStreamingAggregation(
                   {"c0"},
                   {"count(1)", "min(c1)", "max(c1)", "sum(c1)"},
                   {"", "m1", "m2", "m1"})
               .finalAggregation()
               .planNode();

    config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
        .assertResults(
            "SELECT c0, count(1), min(c1) filter (where c1 % 7 = 0), "
            "max(c1) filter (where c1 % 11 = 0), sum(c1) filter (where c1 % 7 = 0) "
            "FROM tmp GROUP BY 1");
  }

  // Runs single-step streaming aggregation with sorted (order by) aggregates
  // over single-key 'keys' and verifies the results against DuckDB.
  void testSortedAggregation(
      const std::vector<VectorPtr>& keys,
      uint32_t outputBatchSize) {
    auto data = addPayload(keys, 2);
    createDuckDbTable(data);

    auto plan = PlanBuilder()
                    .values(data)
                    .streamingAggregation(
                        {"c0"},
                        {"max(c1 order by c2)",
                         "max(c1 order by c2 desc)",
                         "array_agg(c1 order by c2)"},
                        {},
                        core::AggregationNode::Step::kSingle,
                        false)
                    .planNode();

    config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
        .assertResults(
            "SELECT c0, max(c1 order by c2), max(c1 order by c2 desc), array_agg(c1 order by c2) FROM tmp GROUP BY c0");
  }

  // Like testSortedAggregation, but reads input from one file split per key
  // batch and runs serially with and without barrier execution. Verifies
  // barrier/split counts and that the aggregation produced
  // 'expectedNumOutputBatches' output vectors.
  void testSortedAggregationWithBarrier(
      const std::vector<VectorPtr>& keys,
      uint32_t outputBatchSize,
      uint32_t expectedNumOutputBatches) {
    const auto inputVectors = addPayload(keys, 2);

    std::vector<std::shared_ptr<TempFilePath>> tempFiles;
    const int numSplits = keys.size();
    for (int32_t i = 0; i < numSplits; ++i) {
      tempFiles.push_back(TempFilePath::create());
    }
    writeToFiles(toFilePaths(tempFiles), inputVectors);

    createDuckDbTable(inputVectors);

    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    core::PlanNodeId aggregationNodeId;
    auto plan = PlanBuilder(planNodeIdGenerator)
                    .startTableScan()
                    .outputType(
                        std::dynamic_pointer_cast<const RowType>(
                            inputVectors[0]->type()))
                    .endTableScan()
                    .streamingAggregation(
                        {"c0"},
                        {"max(c1 order by c2)",
                         "max(c1 order by c2 desc)",
                         "array_agg(c1 order by c2)"},
                        {},
                        core::AggregationNode::Step::kSingle,
                        false)
                    .capturePlanNodeId(aggregationNodeId)
                    .planNode();

    for (const auto barrierExecution : {false, true}) {
      SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution));
      auto task =
          AssertQueryBuilder(plan, duckDbQueryRunner_)
              .splits(makeHiveConnectorSplits(tempFiles))
              .serialExecution(true)
              .barrierExecution(barrierExecution)
              .config(
                  core::QueryConfig::kPreferredOutputBatchRows,
                  std::to_string(outputBatchSize))
              .assertResults(
                  "SELECT c0, max(c1 order by c2), max(c1 order by c2 desc), array_agg(c1 order by c2) FROM tmp GROUP BY c0");
      const auto taskStats = task->taskStats();
      // With barriers enabled the task drains once per split.
      ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0);
      ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
      ASSERT_EQ(
          velox::exec::toPlanStats(taskStats)
              .at(aggregationNodeId)
              .outputVectors,
          expectedNumOutputBatches);
    }
  }

  // Runs single-step streaming aggregation with distinct and sorted
  // aggregates, plus a distinct-only (no aggregates) variant, and verifies
  // both against DuckDB.
  void testDistinctAggregation(
      const std::vector<VectorPtr>& keys,
      uint32_t outputBatchSize) {
    auto data = addPayload(keys, 2);
    createDuckDbTable(data);

    {
      auto plan = PlanBuilder()
                      .values(data)
                      .streamingAggregation(
                          {"c0"},
                          {"array_agg(distinct c1)",
                           "array_agg(c1 order by c2)",
                           "count(distinct c1)",
                           "array_agg(c2)"},
                          {},
                          core::AggregationNode::Step::kSingle,
                          false)
                      .planNode();

      config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
          .assertResults(
              "SELECT c0, array_agg(distinct c1), array_agg(c1 order by c2), "
              "count(distinct c1), array_agg(c2) FROM tmp GROUP BY c0");
    }

    {
      // No aggregates: streaming aggregation acts as streaming distinct.
      auto plan =
          PlanBuilder()
              .values(data)
              .streamingAggregation(
                  {"c0"}, {}, {}, core::AggregationNode::Step::kSingle, false)
              .planNode();

      config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
          .assertResults("SELECT distinct c0 FROM tmp");
    }
  }

  // Like testDistinctAggregation, but reads input from one file split per key
  // batch and runs serially with and without barrier execution. Verifies
  // barrier/split counts and that the aggregation produced
  // 'expectedNumOutputBatches' output vectors.
  void testDistinctAggregationWithBarrier(
      const std::vector<VectorPtr>& keys,
      uint32_t outputBatchSize,
      uint32_t expectedNumOutputBatches) {
    const auto inputVectors = addPayload(keys, 2);
    std::vector<std::shared_ptr<TempFilePath>> tempFiles;
    const int numSplits = keys.size();
    for (int32_t i = 0; i < numSplits; ++i) {
      tempFiles.push_back(TempFilePath::create());
    }
    writeToFiles(toFilePaths(tempFiles), inputVectors);

    createDuckDbTable(inputVectors);

    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    {
      core::PlanNodeId aggregationNodeId;
      auto plan = PlanBuilder(planNodeIdGenerator)
                      .startTableScan()
                      .outputType(
                          std::dynamic_pointer_cast<const RowType>(
                              inputVectors[0]->type()))
                      .endTableScan()
                      .streamingAggregation(
                          {"c0"},
                          {"array_agg(distinct c1)",
                           "array_agg(c1 order by c2)",
                           "count(distinct c1)",
                           "array_agg(c2)"},
                          {},
                          core::AggregationNode::Step::kSingle,
                          false)
                      .capturePlanNodeId(aggregationNodeId)
                      .planNode();
      for (const auto barrierExecution : {false, true}) {
        SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution));
        auto task =
            AssertQueryBuilder(plan, duckDbQueryRunner_)
                .splits(makeHiveConnectorSplits(tempFiles))
                .serialExecution(true)
                .barrierExecution(barrierExecution)
                .config(
                    core::QueryConfig::kPreferredOutputBatchRows,
                    std::to_string(outputBatchSize))
                .assertResults(
                    "SELECT c0, array_agg(distinct c1), array_agg(c1 order by c2), "
                    "count(distinct c1), array_agg(c2) FROM tmp GROUP BY c0");
        const auto taskStats = task->taskStats();
        ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0);
        ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
        ASSERT_EQ(
            velox::exec::toPlanStats(taskStats)
                .at(aggregationNodeId)
                .outputVectors,
            expectedNumOutputBatches);
      }
    }

    {
      // No aggregates: streaming aggregation acts as streaming distinct.
      core::PlanNodeId aggregationNodeId;
      auto plan =
          PlanBuilder(planNodeIdGenerator)
              .startTableScan()
              .outputType(
                  std::dynamic_pointer_cast<const RowType>(
                      inputVectors[0]->type()))
              .endTableScan()
              .streamingAggregation(
                  {"c0"}, {}, {}, core::AggregationNode::Step::kSingle, false)
              .capturePlanNodeId(aggregationNodeId)
              .planNode();

      for (const auto barrierExecution : {false, true}) {
        SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution));
        auto task = AssertQueryBuilder(plan, duckDbQueryRunner_)
                        .splits(makeHiveConnectorSplits(tempFiles))
                        .serialExecution(true)
                        .barrierExecution(barrierExecution)
                        .config(
                            core::QueryConfig::kPreferredOutputBatchRows,
                            std::to_string(outputBatchSize))
                        .assertResults("SELECT distinct c0 FROM tmp");
        const auto taskStats = task->taskStats();
        ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0);
        ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
        ASSERT_EQ(
            velox::exec::toPlanStats(taskStats)
                .at(aggregationNodeId)
                .outputVectors,
            expectedNumOutputBatches);
      }
    }
  }

  // Wraps each key vector into a row vector with 'numPayloadColumns' extra
  // int32 payload columns whose values increase monotonically across batches.
  std::vector<RowVectorPtr> addPayload(
      const std::vector<VectorPtr>& keys,
      int numPayloadColumns) {
    std::vector<RowVectorPtr> data;
    vector_size_t totalSize = 0;
    for (const auto& keyVector : keys) {
      auto size = keyVector->size();
      auto payload = makeFlatVector<int32_t>(
          size, [totalSize](auto row) { return totalSize + row; });
      std::vector<VectorPtr> columns;
      columns.push_back(keyVector);
      for (int i = 0; i < numPayloadColumns; ++i) {
        columns.push_back(payload);
      }
      data.push_back(makeRowVector(columns));
      totalSize += size;
    }
    return data;
  }

  // Appends one monotonically increasing int32 payload column to each
  // multi-key row vector.
  std::vector<RowVectorPtr> addPayload(const std::vector<RowVectorPtr>& keys) {
    const auto keyCount = numKeys(keys);

    std::vector<RowVectorPtr> data;

    vector_size_t totalSize = 0;
    for (const auto& keyVector : keys) {
      auto size = keyVector->size();
      auto payload = makeFlatVector<int32_t>(
          size, [totalSize](auto row) { return totalSize + row; });

      auto children = keyVector->as<RowVector>()->children();
      VELOX_CHECK_EQ(keyCount, children.size());
      children.push_back(payload);
      data.push_back(makeRowVector(children));
      totalSize += size;
    }
    return data;
  }

  // Number of key columns in multi-key test data.
  static size_t numKeys(const std::vector<RowVectorPtr>& keys) {
    return keys[0]->type()->size();
  }

  // Builds a comma-separated list of grouping key column names to use in
  // DuckDB queries: "c0, c1, c2, ...".
  static std::string groupingKeySql(const std::vector<RowVectorPtr>& keys) {
    std::ostringstream keySql;
    keySql << "c0";
    for (size_t i = 1; i < numKeys(keys); i++) {
      keySql << ", c" << i;
    }
    return keySql.str();
  }

  // Multi-key variant with all key columns pre-grouped.
  void testMultiKeyAggregation(
      const std::vector<RowVectorPtr>& keys,
      uint32_t outputBatchSize) {
    testMultiKeyAggregation(
        keys, keys[0]->type()->asRow().names(), outputBatchSize);
  }

  // Runs partial + final aggregation over multi-key 'keys' with the given
  // subset of pre-grouped keys and verifies results against DuckDB; also
  // re-runs with partial aggregation forced to flush after every batch.
  void testMultiKeyAggregation(
      const std::vector<RowVectorPtr>& keys,
      const std::vector<std::string>& preGroupedKeys,
      uint32_t outputBatchSize) {
    auto data = addPayload(keys);
    createDuckDbTable(data);

    auto plan =
        PlanBuilder()
            .values(data)
            .aggregation(
                keys[0]->type()->asRow().names(),
                preGroupedKeys,
                {"count(1)", "min(c1)", "max(c1)", "sum(c1)", "sumnonpod(1)"},
                {},
                core::AggregationNode::Step::kPartial,
                false)
            .finalAggregation()
            .planNode();

    const auto keySql = groupingKeySql(keys);

    const auto sql = fmt::format(
        "SELECT {}, count(1), min(c1), max(c1), sum(c1), sum(1) FROM tmp GROUP BY {}",
        keySql,
        keySql);

    config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
        .assertResults(sql);

    EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);

    // Force partial aggregation flush after every batch of input.
    AssertQueryBuilder(plan, duckDbQueryRunner_)
        .config(core::QueryConfig::kMaxPartialAggregationMemory, "0")
        .assertResults(sql);

    EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);
  }

  // Runs single-step streaming aggregation over multi-key 'keys' with
  // distinct aggregates, plus a distinct-only (no aggregates) variant, and
  // verifies both against DuckDB.
  void testMultiKeyDistinctAggregation(
      const std::vector<RowVectorPtr>& keys,
      uint32_t outputBatchSize) {
    auto data = addPayload(keys);
    createDuckDbTable(data);

    {
      auto plan =
          PlanBuilder()
              .values(data)
              .streamingAggregation(
                  keys[0]->type()->asRow().names(),
                  {"count(distinct c1)", "array_agg(c1)", "sumnonpod(1)"},
                  {},
                  core::AggregationNode::Step::kSingle,
                  false)
              .planNode();

      const auto keySql = groupingKeySql(keys);

      const auto sql = fmt::format(
          "SELECT {}, count(distinct c1), array_agg(c1), sum(1) FROM tmp GROUP BY {}",
          keySql,
          keySql);

      config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
          .assertResults(sql);

      EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);
    }

    {
      // No aggregates: streaming aggregation acts as streaming distinct.
      auto plan = PlanBuilder()
                      .values(data)
                      .streamingAggregation(
                          keys[0]->type()->asRow().names(),
                          {},
                          {},
                          core::AggregationNode::Step::kSingle,
                          false)
                      .planNode();

      const auto sql =
          fmt::format("SELECT distinct {} FROM tmp", groupingKeySql(keys));

      config(AssertQueryBuilder(plan, duckDbQueryRunner_), outputBatchSize)
          .assertResults(sql);
    }
  }

  // Like testMultiKeyDistinctAggregation, but reads input from one file split
  // per key batch and runs serially with and without barrier execution,
  // verifying barrier/split counts and accumulator cleanup.
  void testMultiKeyDistinctAggregationWithBarrier(
      const std::vector<RowVectorPtr>& keys,
      uint32_t outputBatchSize) {
    const auto inputVectors = addPayload(keys);
    std::vector<std::shared_ptr<TempFilePath>> tempFiles;
    const int numSplits = keys.size();
    for (int32_t i = 0; i < numSplits; ++i) {
      tempFiles.push_back(TempFilePath::create());
    }
    writeToFiles(toFilePaths(tempFiles), inputVectors);

    createDuckDbTable(inputVectors);

    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    {
      core::PlanNodeId aggregationNodeId;
      auto plan =
          PlanBuilder(planNodeIdGenerator)
              .startTableScan()
              .outputType(
                  std::dynamic_pointer_cast<const RowType>(
                      inputVectors[0]->type()))
              .endTableScan()
              .streamingAggregation(
                  keys[0]->type()->asRow().names(),
                  {"count(distinct c1)", "array_agg(c1)", "sumnonpod(1)"},
                  {},
                  core::AggregationNode::Step::kSingle,
                  false)
              .capturePlanNodeId(aggregationNodeId)
              .planNode();

      const auto keySql = groupingKeySql(keys);

      const auto sql = fmt::format(
          "SELECT {}, count(distinct c1), array_agg(c1), sum(1) FROM tmp GROUP BY {}",
          keySql,
          keySql);

      for (const auto barrierExecution : {false, true}) {
        SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution));
        auto task = AssertQueryBuilder(plan, duckDbQueryRunner_)
                        .splits(makeHiveConnectorSplits(tempFiles))
                        .serialExecution(true)
                        .barrierExecution(barrierExecution)
                        .config(
                            core::QueryConfig::kPreferredOutputBatchRows,
                            std::to_string(outputBatchSize))
                        .assertResults(sql);
        const auto taskStats = task->taskStats();
        ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0);
        ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
        EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);
      }
    }

    {
      // No aggregates: streaming aggregation acts as streaming distinct.
      core::PlanNodeId aggregationNodeId;
      auto plan = PlanBuilder(planNodeIdGenerator)
                      .startTableScan()
                      .outputType(
                          std::dynamic_pointer_cast<const RowType>(
                              inputVectors[0]->type()))
                      .endTableScan()
                      .streamingAggregation(
                          keys[0]->type()->asRow().names(),
                          {},
                          {},
                          core::AggregationNode::Step::kSingle,
                          false)
                      .capturePlanNodeId(aggregationNodeId)
                      .planNode();

      const auto sql =
          fmt::format("SELECT distinct {} FROM tmp", groupingKeySql(keys));

      for (const auto barrierExecution : {false, true}) {
        SCOPED_TRACE(fmt::format("barrierExecution {}", barrierExecution));
        auto task = AssertQueryBuilder(plan, duckDbQueryRunner_)
                        .splits(makeHiveConnectorSplits(tempFiles))
                        .serialExecution(true)
                        .barrierExecution(barrierExecution)
                        .config(
                            core::QueryConfig::kPreferredOutputBatchRows,
                            std::to_string(outputBatchSize))
                        .assertResults(sql);
        const auto taskStats = task->taskStats();
        ASSERT_EQ(taskStats.numBarriers, barrierExecution ? numSplits : 0);
        ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
        EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed);
      }
    }
  }
};

// Run every test under the cross product of streaming min output batch rows
// {0, 1, 64, int32 max} and preferred output batch bytes {1, 1024, uint64
// max}. The name generator renders the max sentinels as "inf" in the
// generated test-name suffix.
VELOX_INSTANTIATE_TEST_SUITE_P(
    StreamingAggregationTest,
    StreamingAggregationTest,
    testing::Values(
        TestParams{0, 1},
        TestParams{0, 1024},
        TestParams{0, std::numeric_limits<uint64_t>::max()},
        TestParams{1, 1},
        TestParams{1, 1024},
        TestParams{1, std::numeric_limits<uint64_t>::max()},
        TestParams{64, 1},
        TestParams{64, 1024},
        TestParams{64, std::numeric_limits<uint64_t>::max()},
        TestParams{std::numeric_limits<int32_t>::max(), 1},
        TestParams{std::numeric_limits<int32_t>::max(), 1024},
        TestParams{
            std::numeric_limits<int32_t>::max(),
            std::numeric_limits<uint64_t>::max()}),
    [](const testing::TestParamInfo<TestParams>& info) {
      return fmt::format(
          "streamingMinOutputBatchSize_{}_preferredOutputBatchBytes_{}",
          info.param.streamingMinOutputBatchSize ==
                  std::numeric_limits<int32_t>::max()
              ? "inf"
              : std::to_string(info.param.streamingMinOutputBatchSize),
          info.param.preferredOutputBatchBytes ==
                  std::numeric_limits<uint64_t>::max()
              ? "inf"
              : std::to_string(info.param.preferredOutputBatchBytes));
    });

TEST_P(StreamingAggregationTest, smallInputBatches) {
  // Small key batches where a key value may end mid-batch, at a batch
  // boundary, or run across several consecutive batches.
  std::vector<VectorPtr> keyBatches = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, 2, 2}),
      makeFlatVector<int32_t>({2, 3, 3, 4}),
      makeFlatVector<int32_t>({5, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 7, 8}),
  };

  // Output batch large enough to hold all results at once.
  testAggregation(keyBatches, 1024);

  // Cut output into tiny batches of 3 rows.
  testAggregation(keyBatches, 3);
}

TEST_P(StreamingAggregationTest, multipleKeys) {
  // Two-column grouping keys, including nulls, spread over three batches.
  std::vector<RowVectorPtr> keyBatches = {
      makeRowVector({
          makeFlatVector<int32_t>({1, 1, 2, 2, 2}),
          makeFlatVector<int64_t>({10, 20, 20, 30, 30}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>({2, 3, 3, 3, 4}),
          makeFlatVector<int64_t>({30, 30, 40, 40, 40}),
      }),
      makeRowVector({
          makeNullableFlatVector<int32_t>({5, std::nullopt, 6, 6, 6}),
          makeNullableFlatVector<int64_t>({40, 50, 50, 50, std::nullopt}),
      }),
  };

  // Output batch large enough to hold all results at once.
  testMultiKeyAggregation(keyBatches, 1024);

  // Cut output into tiny batches of 3 rows.
  testMultiKeyAggregation(keyBatches, 3);
}

TEST_P(StreamingAggregationTest, regularSizeInputBatches) {
  const int32_t n = 1'024;

  // Four batches of clustered keys; each key value covers 5 consecutive rows
  // and the final batch is short.
  std::vector<VectorPtr> keys = {
      makeFlatVector<int32_t>(n, [](auto row) { return row / 5; }),
      makeFlatVector<int32_t>(n, [n](auto row) { return (n + row) / 5; }),
      makeFlatVector<int32_t>(n, [n](auto row) { return (2 * n + row) / 5; }),
      makeFlatVector<int32_t>(78, [n](auto row) { return (3 * n + row) / 5; }),
  };

  testAggregation(keys, 1024);

  // Cut output into small batches of size 100.
  testAggregation(keys, 100);
}

TEST_P(StreamingAggregationTest, uniqueKeys) {
  const int32_t n = 1'024;

  // Every key value appears exactly once, strictly increasing across four
  // batches; the final batch is short.
  std::vector<VectorPtr> keys = {
      makeFlatVector<int32_t>(n, [](auto row) { return row; }),
      makeFlatVector<int32_t>(n, [n](auto row) { return n + row; }),
      makeFlatVector<int32_t>(n, [n](auto row) { return 2 * n + row; }),
      makeFlatVector<int32_t>(78, [n](auto row) { return 3 * n + row; }),
  };

  testAggregation(keys, 1024);

  // Cut output into small batches of size 100.
  testAggregation(keys, 100);
}

TEST_P(StreamingAggregationTest, partialStreaming) {
  const int32_t n = 1'024;

  // Two key columns: the first (c0) is clustered / pre-grouped, the second
  // (c1) is not. One clustered value (-1) lasts exactly one batch; another
  // (-5) spans two adjacent batches.
  std::vector<RowVectorPtr> keyBatches = {
      makeRowVector({
          makeFlatVector<int32_t>({-10, -10, -5, -5, -5}),
          makeFlatVector<int32_t>({0, 1, 2, 1, 4}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>({-5, -5, -4, -3, -2}),
          makeFlatVector<int32_t>({0, 1, 2, 1, 4}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>({-1, -1, -1, -1, -1}),
          makeFlatVector<int32_t>({0, 1, 2, 1, 4}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>({0, 0, 0, 0, 0}),
          makeFlatVector<int32_t>({0, 4, 2, 3, 4}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>(n, [](auto row) { return row / 7; }),
          makeFlatVector<int32_t>(n, [](auto row) { return row % 5; }),
      }),
      makeRowVector({
          makeFlatVector<int32_t>(n, [n](auto row) { return (n + row) / 7; }),
          makeFlatVector<int32_t>(n, [n](auto row) { return (n + row) % 5; }),
      }),
  };

  // Only c0 is declared pre-grouped.
  testMultiKeyAggregation(keyBatches, {"c0"}, 1024);
}

// Test StreamingAggregation being closed without being initialized. Create a
// pipeline with Project followed by StreamingAggregation. Make
// Project::initialize fail by using non-existent function.
TEST_P(StreamingAggregationTest, closeUninitialized) {
  auto input = makeRowVector({
      makeFlatVector<int64_t>({1, 2, 3}),
  });

  auto plan =
      PlanBuilder()
          .values({input})
          .addNode([](auto id, auto source) -> core::PlanNodePtr {
            // Project 'c0' through and compute 'x' with an unregistered
            // function name so that initializing the Project operator throws
            // before the downstream aggregation is initialized.
            std::vector<std::string> names{"c0", "x"};
            std::vector<core::TypedExprPtr> exprs{
                std::make_shared<core::FieldAccessTypedExpr>(BIGINT(), "c0"),
                std::make_shared<core::CallTypedExpr>(
                    BIGINT(), "do-not-exist")};
            return std::make_shared<core::ProjectNode>(
                id, std::move(names), std::move(exprs), std::move(source));
          })
          .partialStreamingAggregation({"c0"}, {"sum(x)"})
          .planNode();

  VELOX_ASSERT_THROW(
      AssertQueryBuilder(plan).copyResults(pool()),
      "Scalar function name not registered: do-not-exist");
}

// Sorted aggregations over strictly increasing (unique) keys.
TEST_P(StreamingAggregationTest, sortedAggregations) {
  constexpr int32_t kSize = 512;

  // Three full batches of unique keys plus a short trailing batch of 78 rows.
  std::vector<VectorPtr> keys;
  keys.push_back(makeFlatVector<int32_t>(kSize, [](auto row) { return row; }));
  keys.push_back(
      makeFlatVector<int32_t>(kSize, [](auto row) { return kSize + row; }));
  keys.push_back(makeFlatVector<int32_t>(
      kSize, [](auto row) { return 2 * kSize + row; }));
  keys.push_back(
      makeFlatVector<int32_t>(78, [](auto row) { return 3 * kSize + row; }));

  testSortedAggregation(keys, 512);
  testSortedAggregation(keys, 32);
}

// Distinct aggregations over unique single-column keys and over multi-column
// keys with duplicates both within a batch and across adjacent batches.
TEST_P(StreamingAggregationTest, distinctAggregations) {
  constexpr int32_t kSize = 1024;

  // Strictly increasing keys across four batches; the last batch is short.
  std::vector<VectorPtr> keys;
  keys.push_back(makeFlatVector<int32_t>(kSize, [](auto row) { return row; }));
  keys.push_back(
      makeFlatVector<int32_t>(kSize, [](auto row) { return kSize + row; }));
  keys.push_back(makeFlatVector<int32_t>(
      kSize, [](auto row) { return 2 * kSize + row; }));
  keys.push_back(
      makeFlatVector<int32_t>(78, [](auto row) { return 3 * kSize + row; }));

  testDistinctAggregation(keys, 1024);
  testDistinctAggregation(keys, 32);

  // Two-column keys: some key values repeat within a batch and some continue
  // into the next batch.
  std::vector<RowVectorPtr> multiKeys;
  multiKeys.push_back(makeRowVector({
      makeFlatVector<int32_t>({1, 1, 2, 2, 2}),
      makeFlatVector<int64_t>({10, 20, 20, 30, 30}),
  }));
  multiKeys.push_back(makeRowVector({
      makeFlatVector<int32_t>({2, 3, 3, 3, 4}),
      makeFlatVector<int64_t>({30, 30, 40, 40, 40}),
  }));
  multiKeys.push_back(makeRowVector({
      makeNullableFlatVector<int32_t>({5, 5, 6, 6, 6}),
      makeNullableFlatVector<int64_t>({40, 50, 50, 50, 50}),
  }));

  testMultiKeyDistinctAggregation(multiKeys, 1024);
  testMultiKeyDistinctAggregation(multiKeys, 3);
}

// Streaming aggregation over clustered (pre-grouped) keys, including a null
// key and one group (key 6) that spans three consecutive input batches.
TEST_P(StreamingAggregationTest, clusteredInput) {
  std::vector<VectorPtr> keys = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, 2, 2}),
      makeFlatVector<int32_t>({2, 3, 3, 4}),
      makeFlatVector<int32_t>({5, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 7, 8}),
  };
  // Payload column c1 is a running row number (0..19) across all batches, as
  // reflected by the array_agg results below.
  auto data = addPayload(keys, 1);
  auto plan = PlanBuilder()
                  .values(data)
                  .partialStreamingAggregation(
                      {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"})
                  .finalAggregation()
                  .planNode();
  // Expected: per-key counts, the first payload value per key (arbitrary),
  // and all payload values per key (array_agg). Key 6 accumulates 8 rows
  // across batches 2-4.
  auto expected = makeRowVector({
      makeNullableFlatVector<int32_t>({1, std::nullopt, 2, 3, 4, 5, 6, 7, 8}),
      makeFlatVector<int64_t>({2, 1, 3, 2, 1, 1, 8, 1, 1}),
      makeFlatVector<int32_t>({0, 2, 3, 6, 8, 9, 10, 18, 19}),
      makeArrayVector<int32_t>(
          {{0, 1},
           {2},
           {3, 4, 5},
           {6, 7},
           {8},
           {9},
           {10, 11, 12, 13, 14, 15, 16, 17},
           {18},
           {19}}),
  });
  // Exercise output batches both smaller and larger than the input batches.
  for (auto batchSize : {3, 20}) {
    SCOPED_TRACE(fmt::format("batchSize={}", batchSize));
    config(AssertQueryBuilder(plan), batchSize).assertResults(expected);
  }
}

// Single-step streaming aggregation whose output may be split across multiple
// output batches. Covers three key layouts: groups overlapping batch
// boundaries, groups fully contained in one batch, and a mix of both.
TEST_P(StreamingAggregationTest, clusteredInputWithOutputSplit) {
  // Case 1: key 6 spans three batches; a null key appears mid-batch.
  std::vector<VectorPtr> keysWithOverlap = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, 2, 2}),
      makeFlatVector<int32_t>({2, 3, 3, 4}),
      makeFlatVector<int32_t>({5, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 7, 8}),
  };
  auto dataWithOverlap = addPayload(keysWithOverlap, 1);
  auto planWithOverlap = PlanBuilder()
                             .values(dataWithOverlap)
                             .streamingAggregation(
                                 {"c0"},
                                 {"arbitrary(c1)", "array_agg(c1)"},
                                 {},
                                 core::AggregationNode::Step::kSingle,
                                 false)
                             .planNode();
  // Payload c1 is a running row number (0..19); arbitrary picks the first
  // value per group, array_agg collects all of them.
  const auto expectedWithOverlap = makeRowVector({
      makeNullableFlatVector<int32_t>({1, std::nullopt, 2, 3, 4, 5, 6, 7, 8}),
      makeFlatVector<int32_t>({0, 2, 3, 6, 8, 9, 10, 18, 19}),
      makeArrayVector<int32_t>(
          {{0, 1},
           {2},
           {3, 4, 5},
           {6, 7},
           {8},
           {9},
           {10, 11, 12, 13, 14, 15, 16, 17},
           {18},
           {19}}),
  });
  for (auto batchSize : {1, 3, 20}) {
    SCOPED_TRACE(fmt::format("batchSize={}", batchSize));
    config(AssertQueryBuilder(planWithOverlap), batchSize)
        .assertResults(expectedWithOverlap);
  }

  // Case 2: no group spans a batch boundary.
  std::vector<VectorPtr> keysWithoutOverlap = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, 2, 2}),
      makeFlatVector<int32_t>({3, 3, 4, 4}),
      makeFlatVector<int32_t>({5, 6, 6, 7}),
      makeFlatVector<int32_t>({8, 8, 9, 9}),
      makeFlatVector<int32_t>({10, 11, 12}),
  };
  auto dataWithoutOverlap = addPayload(keysWithoutOverlap, 1);
  auto planWithoutOverlap = PlanBuilder()
                                .values(dataWithoutOverlap)
                                .streamingAggregation(
                                    {"c0"},
                                    {"arbitrary(c1)", "array_agg(c1)"},
                                    {},
                                    core::AggregationNode::Step::kSingle,
                                    false)
                                .planNode();
  const auto expectedWithoutOverlap = makeRowVector(
      {makeNullableFlatVector<int32_t>(
           {1, std::nullopt, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
       makeFlatVector<int32_t>({0, 2, 3, 5, 7, 9, 10, 12, 13, 15, 17, 18, 19}),
       makeArrayVector<int32_t>(
           {{0, 1},
            {2},
            {3, 4},
            {5, 6},
            {7, 8},
            {9},
            {10, 11},
            {12},
            {13, 14},
            {15, 16},
            {17},
            {18},
            {19}})});
  for (auto batchSize : {1, 3, 20}) {
    SCOPED_TRACE(fmt::format("batchSize={}", batchSize));
    config(AssertQueryBuilder(planWithoutOverlap), batchSize)
        .assertResults(expectedWithoutOverlap);
  }

  // Case 3: mixed — null key repeats within the first batch; key 7 spans the
  // third and fourth batches; other groups stay within a single batch.
  std::vector<VectorPtr> mixedKeys = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, std::nullopt, 2}),
      makeFlatVector<int32_t>({3, 3, 4, 4}),
      makeFlatVector<int32_t>({6, 6, 6, 7}),
      makeFlatVector<int32_t>({7, 8, 9, 9}),
      makeFlatVector<int32_t>({10, 11, 12}),
  };
  auto mixedData = addPayload(mixedKeys, 1);
  auto mixedPlan = PlanBuilder()
                       .values(mixedData)
                       .streamingAggregation(
                           {"c0"},
                           {"arbitrary(c1)", "array_agg(c1)"},
                           {},
                           core::AggregationNode::Step::kSingle,
                           false)
                       .planNode();
  const auto expectedMixedResult = makeRowVector(
      {makeNullableFlatVector<int32_t>(
           {1, std::nullopt, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12}),
       makeFlatVector<int32_t>({0, 2, 4, 5, 7, 9, 12, 14, 15, 17, 18, 19}),
       makeArrayVector<int32_t>(
           {{0, 1},
            {2, 3},
            {4},
            {5, 6},
            {7, 8},
            {9, 10, 11},
            {12, 13},
            {14},
            {15, 16},
            {17},
            {18},
            {19}})});
  for (auto batchSize : {1, 3, 20}) {
    SCOPED_TRACE(fmt::format("batchSize={}", batchSize));
    config(AssertQueryBuilder(mixedPlan), batchSize)
        .assertResults(expectedMixedResult);
  }
}

// Streaming aggregation where some payload rows are null. Per the expected
// results below: count(c1) ignores null rows (groups 1, 4, 6, 8, 11 count 0),
// and arbitrary(c1) is null for groups whose rows are all null.
TEST_P(StreamingAggregationTest, clusteredInputWithNulls) {
  std::vector<VectorPtr> keyVectors = {
      makeFlatVector<int32_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 3}),
      makeFlatVector<int32_t>({4, 4, 4, 4, 5, 5, 5, 5, 6, 6}),
      makeFlatVector<int32_t>({7, 7, 7, 8}),
      makeFlatVector<int32_t>({8, 8, 8, 9, 9, 9, 10, 10, 10}),
      makeFlatVector<int32_t>({11, 11, 11}),
  };
  // Payload rows mirror the keys; the trailing lambda marks which rows are
  // null. Groups 1, 4, 6, 8 and 11 end up with all payload rows null.
  std::vector<VectorPtr> dataVectors = {
      makeRowVector(
          {makeFlatVector<int32_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 3}),
           makeFlatVector<int32_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 3})},
          [](auto row) { return row < 3; }),
      makeRowVector(
          {makeFlatVector<int32_t>({4, 4, 4, 4, 5, 5, 5, 5, 6, 6}),
           makeFlatVector<int32_t>({4, 4, 4, 4, 5, 5, 5, 5, 6, 6})},
          [](auto row) { return row < 4 || row > 7; }),

      makeRowVector(
          {makeFlatVector<int32_t>({7, 7, 7, 8}),
           makeFlatVector<int32_t>({7, 7, 7, 8})},
          [](auto row) { return row > 2; }),

      makeRowVector(
          {makeFlatVector<int32_t>({8, 8, 8, 9, 9, 9, 10, 10, 10}),
           makeFlatVector<int32_t>({8, 8, 8, 9, 9, 9, 10, 10, 10})},
          [](auto row) { return row < 3; }),

      makeRowVector(
          {makeFlatVector<int32_t>({11, 11, 11}),
           makeFlatVector<int32_t>({11, 11, 11})},
          [](auto /*unused*/) { return true; })};
  ASSERT_EQ(keyVectors.size(), dataVectors.size());

  // Zip keys and payloads into one row vector per batch. Use size_t to avoid
  // a signed/unsigned comparison with vector::size().
  std::vector<RowVectorPtr> rowVectors;
  rowVectors.reserve(keyVectors.size());
  for (size_t i = 0; i < keyVectors.size(); ++i) {
    rowVectors.emplace_back(makeRowVector({keyVectors[i], dataVectors[i]}));
  }

  const auto plan =
      PlanBuilder()
          .values(rowVectors)
          .partialStreamingAggregation({"c0"}, {"count(c1)", "arbitrary(c1)"})
          .finalAggregation()
          .planNode();

  // Rows 0, 3, 5, 7, 10 of the arbitrary() column (keys 1, 4, 6, 8, 11) are
  // null because all payload rows of those groups are null.
  const auto expected = makeRowVector(
      {makeNullableFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}),
       makeFlatVector<int64_t>({0, 3, 4, 0, 4, 0, 3, 0, 3, 3, 0}),
       makeRowVector(
           {makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}),
            makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})},
           [](auto row) {
             if (row == 0 || row == 3 || row == 5 || row == 7 || row == 10) {
               return true;
             }
             return false;
           })});
  for (auto batchSize : {20}) {
    SCOPED_TRACE(fmt::format("batchSize={}", batchSize));
    config(AssertQueryBuilder(plan), batchSize).assertResults(expected);
  }
}

// Sorted aggregations over unique keys, exercised under barriered execution.
TEST_P(StreamingAggregationTest, sortedAggregationsWithBarrier) {
  constexpr int32_t kSize = 1024;

  // Three full batches of strictly increasing keys plus a short 78-row batch.
  std::vector<VectorPtr> keys;
  keys.push_back(makeFlatVector<int32_t>(kSize, [](auto row) { return row; }));
  keys.push_back(
      makeFlatVector<int32_t>(kSize, [](auto row) { return kSize + row; }));
  keys.push_back(makeFlatVector<int32_t>(
      kSize, [](auto row) { return 2 * kSize + row; }));
  keys.push_back(
      makeFlatVector<int32_t>(78, [](auto row) { return 3 * kSize + row; }));

  testSortedAggregationWithBarrier(keys, 1024, 4);
  testSortedAggregationWithBarrier(keys, 32, 4);
}

// Streaming aggregation reading from table-scan splits, with and without
// barriered execution; verifies both results and the number of output
// batches reported in plan stats.
TEST_P(StreamingAggregationTest, clusteredInputWithBarrier) {
  // One file/split per key vector; all keys are unique so every group has a
  // single row.
  const std::vector<VectorPtr> keys = {
      makeNullableFlatVector<int32_t>({1, 2, std::nullopt, 3, 4}),
      makeFlatVector<int32_t>({9, 10, 11, 12}),
      makeFlatVector<int32_t>({17, 18, 19}),
  };
  auto inputVectors = addPayload(keys, 1);
  std::vector<std::shared_ptr<TempFilePath>> tempFiles;
  const int numSplits = keys.size();
  for (int32_t i = 0; i < numSplits; ++i) {
    tempFiles.push_back(TempFilePath::create());
  }
  writeToFiles(toFilePaths(tempFiles), inputVectors);

  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
  core::PlanNodeId streamingAggregationNodeId;
  auto plan =
      PlanBuilder(planNodeIdGenerator)
          .startTableScan()
          .outputType(
              std::dynamic_pointer_cast<const RowType>(inputVectors[0]->type()))
          .endTableScan()
          .partialStreamingAggregation(
              {"c0"}, {"count(c1)", "arbitrary(c1)", "array_agg(c1)"})
          .capturePlanNodeId(streamingAggregationNodeId)
          .finalAggregation()
          .planNode();
  // Every group holds exactly one row: count is 1 and arbitrary/array_agg
  // return the single payload value (running row number 0..11).
  const auto expected = makeRowVector(
      {makeNullableFlatVector<int32_t>(
           {1, 2, std::nullopt, 3, 4, 9, 10, 11, 12, 17, 18, 19}),
       makeFlatVector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}),
       makeFlatVector<int32_t>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}),
       makeArrayVector<int32_t>(
           {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}, {11}})});
  struct {
    int batchSize;
    bool barrierExecution;
    int numExpectedOutputBatches;

    std::string debugString() const {
      return fmt::format(
          "batchSize={}, barrierExecution={}, numExpectedOutputBatches={}",
          batchSize,
          barrierExecution,
          numExpectedOutputBatches);
    }
  } testSettings[] = {
      // With batch size 20 and no barriers, all 12 groups fit in 1 output
      // batch; with barriers, output is produced per split (3 batches).
      {3, true, 3}, {3, false, 3}, {20, true, 3}, {20, false, 1}};

  for (const auto& testData : testSettings) {
    SCOPED_TRACE(testData.debugString());

    auto task = AssertQueryBuilder(plan, duckDbQueryRunner_)
                    .splits(makeHiveConnectorSplits(tempFiles))
                    .serialExecution(true)
                    .barrierExecution(testData.barrierExecution)
                    .config(
                        core::QueryConfig::kPreferredOutputBatchRows,
                        std::to_string(testData.batchSize))
                    .assertResults(expected);
    // Barriered runs report one barrier per split; split counts match either
    // way.
    const auto taskStats = task->taskStats();
    ASSERT_EQ(taskStats.numBarriers, testData.barrierExecution ? numSplits : 0);
    ASSERT_EQ(taskStats.numFinishedSplits, numSplits);
    ASSERT_EQ(
        velox::exec::toPlanStats(taskStats)
            .at(streamingAggregationNodeId)
            .outputVectors,
        testData.numExpectedOutputBatches);
  }
}

// Distinct aggregations under barriered execution, over unique single-column
// keys and over multi-column keys with duplicates.
TEST_P(StreamingAggregationTest, distinctAggregationsWithBarrier) {
  const auto size = 1024;
  // Strictly increasing keys across four batches; the last batch is short.
  const std::vector<VectorPtr> keys = {
      makeFlatVector<int32_t>(size, [](auto row) { return row; }),
      makeFlatVector<int32_t>(size, [size](auto row) { return (size + row); }),
      makeFlatVector<int32_t>(
          size, [size](auto row) { return (2 * size + row); }),
      makeFlatVector<int32_t>(
          78, [size](auto row) { return (3 * size + row); }),
  };

  testDistinctAggregationWithBarrier(keys, 1024, 4);
  testDistinctAggregationWithBarrier(keys, 32, 4);

  // Two-column keys with duplicates inside a batch and across batches.
  std::vector<RowVectorPtr> multiKeys = {
      makeRowVector({
          makeFlatVector<int32_t>({1, 1, 2, 2, 2}),
          makeFlatVector<int64_t>({10, 20, 20, 30, 40}),
      }),
      makeRowVector({
          makeFlatVector<int32_t>({3, 3, 3, 3, 4}),
          makeFlatVector<int64_t>({30, 40, 50, 60, 40}),
      }),
      makeRowVector({
          makeNullableFlatVector<int32_t>({5, 5, 6, 6, 6}),
          makeNullableFlatVector<int64_t>({40, 50, 60, 70, 80}),
      }),
  };
  testMultiKeyDistinctAggregationWithBarrier(multiKeys, 1024);
  testMultiKeyDistinctAggregationWithBarrier(multiKeys, 3);
}

// Aggregates a constant argument: array_agg(3) collects the literal once per
// row of each group.
TEST_P(StreamingAggregationTest, constantInput) {
  auto input = makeRowVector({makeFlatVector<int32_t>({1, 1, 2, 2})});

  auto plan = PlanBuilder()
                  .values({input})
                  .partialStreamingAggregation({"c0"}, {"array_agg(3)"})
                  .finalAggregation()
                  .planNode();

  // Two groups of two rows each; every row contributes the constant 3.
  auto expected = makeRowVector({
      makeFlatVector<int32_t>({1, 2}),
      makeArrayVector<int64_t>({{3, 3}, {3, 3}}),
  });

  for (auto outputBatchSize : {1, 10}) {
    config(AssertQueryBuilder(plan), outputBatchSize).assertResults(expected);
  }
}

// Verifies how kPreferredOutputBatchBytes interacts with the streaming
// min-output-batch-rows setting in determining output batch count.
TEST_P(StreamingAggregationTest, preferredOutputBatchBytes) {
  // Use grouping keys that span one or more batches.
  std::vector<VectorPtr> keys = {
      makeNullableFlatVector<int32_t>({1, 1, std::nullopt, 2, 2}),
      makeFlatVector<int32_t>({2, 3, 3, 4}),
      makeFlatVector<int32_t>({5, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 6, 6, 6}),
      makeFlatVector<int32_t>({6, 7, 8}),
  };

  auto data = addPayload(keys, 1);

  // NOTE: 'sumnonpod' (registered in SetUp) exercises a non-POD accumulator;
  // the null-typed sum exercises an all-null aggregate result.
  auto plan = PlanBuilder()
                  .values(data)
                  .partialStreamingAggregation(
                      {"c0"},
                      {"count(1)",
                       "min(c1)",
                       "max(c1)",
                       "sum(c1)",
                       "sumnonpod(1)",
                       "sum(cast(NULL as INT))"})
                  .finalAggregation()
                  .planNode();

  auto results =
      config(AssertQueryBuilder(plan), 1024).copyResultBatches(pool_.get());

  // If streamingMinOutputBatchSize is set to 1, we expect an output batch for:
  // {1, NULL}, {2}, {3, 4}, {5}, {6}, {7, 8}.
  // Otherwise, we expect the output batches to be determined by
  // preferredOutputBatchBytes.
  size_t expectedOutputBatches;
  if (GetParam().streamingMinOutputBatchSize == 1) {
    expectedOutputBatches = 6;
  } else if (GetParam().preferredOutputBatchBytes == 1) {
    expectedOutputBatches = 5;
  } else if (GetParam().preferredOutputBatchBytes == 1024) {
    expectedOutputBatches = 2;
  } else {
    // The parameter grid is expected to contain exactly 1, 1024, and
    // uint64_t max for preferredOutputBatchBytes; fail loudly otherwise.
    ASSERT_EQ(
        GetParam().preferredOutputBatchBytes,
        std::numeric_limits<uint64_t>::max());
    expectedOutputBatches = 1;
  }

  ASSERT_EQ(results.size(), expectedOutputBatches);
}

// Tests that when noGroupsSpanBatches is set, the number of output batches
// matches the number of input batches when minOutputBatchRows is set to 1.
// When minOutputBatchRows is set to an extremely large value, we expect a
// single output batch.
// NOTE(review): this is TEST_F on a fixture that derives from
// WithParamInterface, unlike the TEST_P tests in this file. It works because
// GetParam() is never called here (the config is set explicitly below), but
// confirm whether TEST_P was intended for consistency.
TEST_F(StreamingAggregationTest, noGroupsSpanBatches) {
  // Create input batches where no group spans across batches.
  // Each batch has unique grouping keys that don't appear in other batches.
  std::vector<VectorPtr> keys = {
      makeFlatVector<int32_t>({1, 1, 2, 2}),
      makeFlatVector<int32_t>({3, 3, 4, 4}),
      makeFlatVector<int32_t>({5, 5, 6, 6}),
      makeFlatVector<int32_t>({7, 7, 8, 8}),
      makeFlatVector<int32_t>({9, 9, 10, 10}),
  };

  auto data = addPayload(keys, 1);
  createDuckDbTable(data);

  struct {
    int32_t minOutputBatchRows;
    size_t expectedOutputBatches;

    std::string debugString() const {
      return fmt::format(
          "minOutputBatchRows={}, expectedOutputBatches={}",
          minOutputBatchRows,
          expectedOutputBatches);
    }
  } testSettings[] = {
      // When minOutputBatchRows is 1, each input batch produces an output batch
      {1, keys.size()},
      // When minOutputBatchRows is very large, all groups are batched together
      // into a single output
      {std::numeric_limits<int32_t>::max(), 1},
  };

  for (const auto& testData : testSettings) {
    SCOPED_TRACE(testData.debugString());

    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    core::PlanNodeId aggregationNodeId;
    auto plan = PlanBuilder(planNodeIdGenerator)
                    .values(data)
                    .streamingAggregation(
                        {"c0"},
                        {"count(1)", "sum(c1)"},
                        {},
                        core::AggregationNode::Step::kSingle,
                        /*ignoreNullKeys=*/false,
                        /*noGroupsSpanBatches=*/true)
                    .capturePlanNodeId(aggregationNodeId)
                    .planNode();

    // Results are cross-checked against DuckDB.
    auto task =
        AssertQueryBuilder(plan, duckDbQueryRunner_)
            .config(
                core::QueryConfig::kStreamingAggregationMinOutputBatchRows,
                std::to_string(testData.minOutputBatchRows))
            .assertResults("SELECT c0, count(1), sum(c1) FROM tmp GROUP BY c0");

    // Verify the number of output batches.
    const auto taskStats = task->taskStats();
    ASSERT_EQ(
        velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors,
        testData.expectedOutputBatches);
  }
}

namespace {
// Test-only plan node wrapping a single source. It carries the two counters
// that InputSourceOperator uses to throttle its needsInput() answers.
class InputSourceNode : public core::PlanNode {
 public:
  InputSourceNode(
      const core::PlanNodeId& id,
      int numInitialInputCalls,
      int numSkipInputCalls,
      core::PlanNodePtr source)
      : PlanNode{id},
        numInitialInputCalls_(numInitialInputCalls),
        numSkipInputCalls_(numSkipInputCalls),
        sources_{std::move(source)} {}

  // Pass-through node: output type is the single source's output type.
  const RowTypePtr& outputType() const override {
    return sources_[0]->outputType();
  }

  const std::vector<core::PlanNodePtr>& sources() const override {
    return sources_;
  }

  // Number of leading needsInput() calls that should return true.
  int numInitialInputCalls() const {
    return numInitialInputCalls_;
  }

  // Number of subsequent needsInput() calls that should return false.
  int numSkipInputCalls() const {
    return numSkipInputCalls_;
  }

  // NOTE(review): the reported name does not match the class name and looks
  // copied from another test helper — confirm before relying on it in plan
  // printouts.
  std::string_view name() const override {
    return "external blocking node";
  }

 private:
  void addDetails(std::stringstream& /* stream */) const override {}

  const int numInitialInputCalls_;
  const int numSkipInputCalls_;
  const std::vector<core::PlanNodePtr> sources_;
};

// Test-only pass-through operator that throttles input acceptance:
// needsInput() first answers true, then false, then true until no-more-input,
// driven by the counters configured on the InputSourceNode.
class InputSourceOperator : public exec::Operator {
 public:
  InputSourceOperator(
      int32_t operatorId,
      exec::DriverCtx* driverCtx,
      std::shared_ptr<const InputSourceNode> node)
      : Operator(
            driverCtx,
            node->outputType(),
            operatorId,
            node->id(),
            "InputSource"),
        numInitialInputCalls_(node->numInitialInputCalls()),
        numSkipInputCalls_(node->numSkipInputCalls()) {}

  // Counters are decremented on every call. Note the '>= 0' test: a counter
  // configured to k yields k + 1 matching calls before falling through.
  bool needsInput() const override {
    if (numInitialInputCalls_-- >= 0) {
      return true;
    }
    if (numSkipInputCalls_-- >= 0) {
      return false;
    }
    return !noMoreInput_;
  }

  // Buffers the single most recent input vector.
  void addInput(RowVectorPtr input) override {
    input_ = std::move(input);
  }

  // Releases the buffered vector (or nullptr if none pending).
  RowVectorPtr getOutput() override {
    auto output = std::move(input_);
    input_ = nullptr;
    return output;
  }

  exec::BlockingReason isBlocked(ContinueFuture* future) override {
    return exec::BlockingReason::kNotBlocked;
  }

  bool isFinished() override {
    return noMoreInput_;
  }

 private:
  // 'mutable' because needsInput() is const yet must advance the counters.
  mutable int numInitialInputCalls_{0};
  mutable int numSkipInputCalls_{0};
  RowVectorPtr input_;
};

// Translates InputSourceNode plan nodes into InputSourceOperator instances.
// Returns nullptr for any other node type so other translators can handle it.
class SourceNodeTranslator : public exec::Operator::PlanNodeTranslator {
  std::unique_ptr<exec::Operator> toOperator(
      exec::DriverCtx* ctx,
      int32_t id,
      const core::PlanNodePtr& node) override {
    auto sourceNode = std::dynamic_pointer_cast<const InputSourceNode>(node);
    if (sourceNode == nullptr) {
      return nullptr;
    }
    return std::make_unique<InputSourceOperator>(
        id, ctx, std::move(sourceNode));
  }
};
} // namespace

// Verifies that StreamingAggregation keeps making progress when the
// downstream operator (InputSourceOperator) alternates its needsInput()
// answers, forcing the aggregation's output to be held and split.
TEST_P(StreamingAggregationTest, needsInputWhenSplitOutput) {
  exec::Operator::registerOperator(std::make_unique<SourceNodeTranslator>());
  const auto size = 32;
  const auto numBatches{5};
  std::vector<RowVectorPtr> batches;
  for (int i = 0; i < numBatches; ++i) {
    // The first key of batch i is i*size - 1, which equals the last key of
    // batch i-1, so one group spans each pair of adjacent batches.
    batches.push_back(makeRowVector(
        {makeFlatVector<int32_t>(
             size,
             [i, size](auto row) {
               return row == 0 ? i * size + row - 1 : i * size + row;
             }),
         makeFlatVector<int32_t>(size, [](auto row) { return row; })}));
  }
  createDuckDbTable(batches);

  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
  core::PlanNodeId aggregationNodeId;
  auto plan = PlanBuilder(planNodeIdGenerator)
                  .values(batches)
                  .streamingAggregation(
                      {"c0"},
                      {"array_agg(c1)"},
                      {},
                      core::AggregationNode::Step::kSingle,
                      false)
                  // No captures, so 'mutable' (which only affects captured
                  // state, never parameters) was meaningless and is dropped.
                  .addNode([](std::string id, core::PlanNodePtr input) {
                    return std::make_shared<InputSourceNode>(
                        id, 2, numBatches - 2, input);
                  })
                  .project({"c0", "a0"})
                  .capturePlanNodeId(aggregationNodeId)
                  .planNode();

  // Results are cross-checked against DuckDB; min output batch rows of 1
  // makes the aggregation flush eagerly.
  auto task =
      AssertQueryBuilder(plan, duckDbQueryRunner_)
          .serialExecution(true)
          .config(
              core::QueryConfig::kStreamingAggregationMinOutputBatchRows, "1")
          .assertResults("SELECT c0, array_agg(c1) FROM tmp GROUP BY c0");
  const auto taskStats = task->taskStats();
  // With the source operator throttling input, the 5 input batches are
  // emitted as 9 output vectors.
  ASSERT_EQ(
      velox::exec::toPlanStats(taskStats).at(aggregationNodeId).outputVectors,
      9);
}
} // namespace
} // namespace facebook::velox::exec
